import copy

from detectron2.engine import DefaultTrainer, DefaultPredictor, SimpleTrainer
from detectron2.utils.events import EventStorage

from centermask.config import get_cfg
import os
from detectron2.data.datasets import register_coco_instances
from detectron2.data import MetadataCatalog
from detectron2.data import DatasetCatalog
import torch
from detectron2.modeling import build_model
from detectron2.utils.visualizer import ColorMode

import random
import cv2
# Register the dataset
from detectron2.utils.visualizer import Visualizer

from detectron2.data import build_detection_train_loader
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils

# Convert the prediction output for an image into a binary mask image
def turn2Mask(predictions=None):
    """Merge all predicted instance masks into one binary (0/255) mask image.

    The merged mask is written to ``cat2-bak.jpg`` and shown in an OpenCV
    window (blocks until a key is pressed).

    Args:
        predictions: optional predictor output dict containing an
            ``instances`` field. Defaults to the module-global ``outputs``
            (set by ``prediction_imgs``) for backward compatibility.
    """
    preds = outputs if predictions is None else predictions
    # pred_masks is a (num_instances, H, W) boolean tensor.
    img_mask = preds['instances']._fields['pred_masks']
    # A pixel is foreground if ANY instance covers it. The original code
    # squeezed the tensor (which collapsed it to 2-D and crashed when exactly
    # one instance was predicted) and summed 255 per instance (which exceeded
    # 255 wherever instances overlapped). ``any`` fixes both issues.
    totalmask = img_mask.any(dim=0).to(torch.uint8) * 255
    img_mask = totalmask.cpu().numpy()
    cv2.imwrite("cat2-bak.jpg", img_mask)
    cv2.imshow("test", img_mask)
    cv2.waitKey(0)


# TODO: read a whole folder of images for prediction, and output the predicted coordinates/metadata
def prediction_imgs(predictions, img_path='input1.jpg'):
    """Run the predictor on one image and display the visualized instances.

    Stores the raw predictor output in the module-global ``outputs`` so
    ``turn2Mask`` can consume it afterwards. Blocks until a key is pressed
    in the OpenCV window.

    Args:
        predictions: a callable predictor (e.g. ``DefaultPredictor``) that
            takes a BGR image array and returns a dict with an
            ``instances`` key.
        img_path: path of the image to predict; defaults to the previously
            hard-coded ``'input1.jpg'`` for backward compatibility.
    """
    global outputs
    im = cv2.imread(img_path)
    outputs = predictions(im)
    # Visualizer expects RGB input, hence the ::-1 channel flip.
    vis = Visualizer(im[:, :, ::-1], metadata=fruits_nuts_metadata, scale=0.5,
                     instance_mode=ColorMode.IMAGE_BW)  # remove the colors of unsegmented pixels
    # Move instances to CPU before drawing so this also works when the
    # model runs on GPU (matches the reference usage in detectron2 demos).
    v = vis.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2.imshow("test", v.get_image()[:, :, ::-1])
    cv2.waitKey(0)


from detectron2.data import build_detection_train_loader
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils

def mapper(dataset_dict):
    """Custom dataset mapper: resize each image to 800x800 and build Instances.

    Mirrors the default detectron2 ``DatasetMapper`` but applies a fixed
    ``Resize`` transform. The caller's dict is never mutated; a transformed
    deep copy is returned with ``image`` and ``instances`` fields populated.
    """
    record = copy.deepcopy(dataset_dict)  # protect the caller's dict from in-place edits
    img = utils.read_image(record["file_name"], format="BGR")
    img, tfms = T.apply_transform_gens([T.Resize((800, 800))], img)
    # HWC uint8 array -> CHW float32 tensor, as the model expects.
    record["image"] = torch.as_tensor(img.transpose(2, 0, 1).astype("float32"))

    transformed_annos = []
    for anno in record.pop("annotations"):
        if anno.get("iscrowd", 0) != 0:
            continue  # crowd regions are skipped, as in the default mapper
        transformed_annos.append(
            utils.transform_instance_annotations(anno, tfms, img.shape[:2])
        )
    record["instances"] = utils.filter_empty_instances(
        utils.annotations_to_instances(transformed_annos, img.shape[:2])
    )
    return record


# use this dataloader instead of the default

if __name__ == '__main__':
    # Register the custom COCO-format dataset and grab its metadata.
    register_coco_instances("fruits_nuts", {}, "./data/trainval.json", "./data/images")
    fruits_nuts_metadata = MetadataCatalog.get("fruits_nuts")
    dataset_dicts = DatasetCatalog.get("fruits_nuts")

    cfg = get_cfg()
    cfg.merge_from_file("../configs/centermask/centermask_lite_V_19_eSE_FPN_ms_4x.yaml")
    cfg.DATASETS.TRAIN = ("fruits_nuts",)
    # cfg.DATASETS.TEST = ()   # no metrics implemented for this dataset
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.MODEL.WEIGHTS = "../configs/centermask/vovnet19_ese_detectron2.pth"  # initialize from model zoo
    cfg.SOLVER.IMS_PER_BATCH = 1  # originally 2, reduced to 1
    cfg.SOLVER.BASE_LR = 0.02
    cfg.SOLVER.MAX_ITER = 1000    # 300 iterations seems good enough, but you can certainly train longer
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256  # originally 128; good enough for this toy dataset
    # All heads must agree on the number of classes in this dataset.
    cfg.MODEL.FCOS.NUM_CLASSES = 3
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.RETINANET.NUM_CLASSES = 3
    cfg.MODEL.DEVICE = "cpu"

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    # data_loader = build_detection_train_loader(cfg, mapper=mapper)

    # Named switches instead of the previous `if 0==1:` / `if 1==1:` literals.
    DO_TRAIN = False
    DO_PREDICT = True

    if DO_TRAIN:
        # Train with the default trainer; checkpoints and logs are written
        # to cfg.OUTPUT_DIR automatically.
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load(resume=False)
        trainer.train()

    if DO_PREDICT:
        # Load the trained weights and run prediction.
        cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # testing threshold for this model
        cfg.DATASETS.TEST = ("fruits_nuts",)
        predictor = DefaultPredictor(cfg)
        # Predict the sample image and visualize the result.
        prediction_imgs(predictor)
        # Convert the predicted masks into a single binary mask image.
        turn2Mask()